import time
import bisect
from multiprocessing.dummy import Pool
from threading import Thread
from collections import deque

from quant.utils import EventEngine, catch_exception, min_, max_, LimitDict, logging
from quant.markets import functions


class MarketInfoEngine(EventEngine):
    """Event bus for market-side notifications.

    Topics and the positional arguments published on them:

    * ``'RAW'``          -> (event, exchange, symbol, frequency, recv_time, raw)
    * ``'channel_recv'`` -> (channel,)
    * ``'process_recv'`` -> (data_id, server_time, recv_time)
    """

    _TOPIC_RAW = 'RAW'
    _TOPIC_CHANNEL_RECV = 'channel_recv'
    _TOPIC_PROCESS_RECV = 'process_recv'

    def subscribe_raw(self, callback):
        """Register ``callback`` for the ``'RAW'`` topic."""
        self.subscribe(self._TOPIC_RAW, callback)

    def unsubscribe_raw(self, handler):
        """Remove ``handler`` from the ``'RAW'`` topic."""
        self.unsubscribe(self._TOPIC_RAW, handler)

    def push_raw(self, event, exchange, symbol, frequency, recv_time, raw):
        """Publish one raw message on the ``'RAW'`` topic."""
        payload = (event, exchange, symbol, frequency, recv_time, raw)
        self.put(self._TOPIC_RAW, *payload)

    def subscribe_channel_recv(self, handler):
        """Register ``handler`` for the ``'channel_recv'`` topic."""
        self.subscribe(self._TOPIC_CHANNEL_RECV, handler)

    def push_channel_recv(self, channel):
        """Publish ``channel`` on the ``'channel_recv'`` topic."""
        self.put(self._TOPIC_CHANNEL_RECV, channel)

    def subscribe_process_recv(self, handler):
        """Register ``handler`` for the ``'process_recv'`` topic."""
        self.subscribe(self._TOPIC_PROCESS_RECV, handler)

    def push_process_recv(self, data_id, server_time, recv_time):
        """Publish a processed-data notification on the ``'process_recv'`` topic."""
        payload = (data_id, server_time, recv_time)
        self.put(self._TOPIC_PROCESS_RECV, *payload)


class RoutingEngine(EventEngine):
    """Event bus that fans market data out by routing key.

    Every item is published twice: first on the catch-all ``ALL_DATA`` topic,
    then on the topic named by its own ``routing_key``.
    """

    ALL_DATA = 'ALL_DATA'

    def push_data(self, routing_key, market_data):
        """Publish ``market_data`` on ``ALL_DATA`` and then on ``routing_key``."""
        for topic in (self.ALL_DATA, routing_key):
            self.put(topic, market_data)


class Mapping:
    """Four-level nested dictionary keyed by (event, exchange, symbol, frequency).

    ``get`` raises ``KeyError`` when any level is missing; ``get_default`` and
    ``set_default`` are the non-raising variants.
    """

    def __init__(self):
        self._dic = {}

    def __repr__(self):
        return repr(self._dic)

    def get(self, event, exchange, symbol, frequency):
        """Return the stored value; raise ``KeyError`` if any key level is missing."""
        by_exchange = self._dic[event]
        by_symbol = by_exchange[exchange]
        by_frequency = by_symbol[symbol]
        return by_frequency[frequency]

    def get_default(self, event, exchange, symbol, frequency, default=None):
        """Return the stored value, or ``default`` when any level is missing."""
        try:
            return self.get(event, exchange, symbol, frequency)
        except KeyError:
            return default

    def set(self, event, exchange, symbol, frequency, value):
        """Store ``value``, creating the intermediate dict levels as needed."""
        node = self._dic
        for key in (event, exchange, symbol):
            node = node.setdefault(key, {})
        node[frequency] = value

    def set_default(self, event, exchange, symbol, frequency, value):
        """Return the existing value; if absent, store ``value`` and return it."""
        try:
            return self.get(event, exchange, symbol, frequency)
        except KeyError:
            self.set(event, exchange, symbol, frequency, value)
            return value

    def values(self):
        """Yield the *elements* of every leaf value, flattened.

        Each leaf value is itself iterated. Iterating this generator raises
        ``TypeError`` if a leaf value is not iterable.
        NOTE(review): presumably leaves are lists built via
        ``set_default(..., [])``; confirm with callers.
        """
        for by_exchange in self._dic.values():
            for by_symbol in by_exchange.values():
                for by_frequency in by_symbol.values():
                    for leaf in by_frequency.values():
                        for item in leaf:
                            yield item


class Mapping2:
    """Three-level nested dictionary keyed by (event, exchange, symbol).

    ``get`` raises ``KeyError`` when any level is missing; ``get_default`` is
    the non-raising variant.
    """

    def __init__(self):
        self._dic = {}

    def __repr__(self):
        return self._dic.__repr__()

    def set(self, event, exchange, symbol, value):
        """Store ``value``, creating the intermediate dict levels as needed."""
        try:
            self._dic[event][exchange][symbol] = value
        except KeyError:
            self._dic.setdefault(event, {}).setdefault(exchange, {})[symbol] = value

    def get(self, event, exchange, symbol):
        """Return the stored value; raise ``KeyError`` if any level is missing."""
        return self._dic[event][exchange][symbol]

    def get_default(self, event, exchange, symbol, default=None):
        """Return the stored value, or ``default`` when any level is missing.

        ``default`` is optional (``None``), matching ``Mapping.get_default``
        and ``Mapping3.get_default``; callers passing it explicitly still work.
        """
        try:
            return self._dic[event][exchange][symbol]
        except KeyError:
            return default


class Mapping3:
    """Two-level nested dictionary keyed by (exchange, symbol)."""

    def __init__(self):
        self._dic = {}

    def __repr__(self):
        return repr(self._dic)

    def set(self, exchange, symbol, value):
        """Store ``value`` under (exchange, symbol), creating the exchange level if needed."""
        self._dic.setdefault(exchange, {})[symbol] = value

    def get(self, exchange, symbol):
        """Return the stored value; raise ``KeyError`` if either level is missing."""
        return self._dic[exchange][symbol]

    def get_default(self, exchange, symbol, default=None):
        """Return the stored value, or ``default`` when either level is missing."""
        try:
            return self.get(exchange, symbol)
        except KeyError:
            return default

    def items(self):
        """Yield ``(exchange, symbol, value)`` triples in insertion order."""
        for exchange, by_symbol in self._dic.items():
            for entry in by_symbol.items():
                yield (exchange,) + entry


class Timer:
    """Run periodic callbacks and delayed one-shot calls off an external clock.

    The timer never reads a clock itself: its owner drives it by calling
    ``step(timestamp)``. Callbacks added with ``register(fn, interval)`` run
    about every ``interval`` units of that clock; the first run comes one
    interval after the first dispatching step following registration.
    ``call_later`` delegates to a ``_Caller``, which is itself registered as a
    periodic callback.
    """

    def __init__(self):
        self._last_step = 0  # timestamp of the last step() that dispatched callbacks
        self._subscribers = {}  # {func: interval, ...}
        self._call_time = {}  # {func: next_time, ...}; 0 = registered, not yet scheduled
        self._caller = _Caller(self, precision=0.25)
        # NOTE(review): unused since async dispatch in step() was disabled,
        # yet multiprocessing.dummy.Pool starts 10 worker threads here.
        self._pool = Pool(10)
        self.now = None  # timestamp passed to the most recent step()

    def step(self, timestamp):
        """Advance the clock to ``timestamp`` and run every callback that is due.

        Steps closer than ``_timer_precision`` to the last dispatching step
        only update ``now``.
        """
        self.now = timestamp

        if timestamp - self._last_step < _timer_precision:
            return

        self._last_step = timestamp
        # Iterate over a snapshot. A callback (including one run through
        # _Caller) may call register(). Growing the dict while iterating its
        # live view raises RuntimeError and aborts the remaining dispatch.
        # Functions registered mid-step are picked up on the next step.
        for func, next_time in list(self._call_time.items()):
            if timestamp > next_time:
                self._call(func, timestamp, next_time)

    def register(self, fn, interval):
        """Call ``fn`` every ``interval`` clock units (first run one interval after the next dispatching step)."""
        self._subscribers[fn] = interval
        self._call_time[fn] = 0

    def call_later(self, func, after):
        """Call ``func`` once, roughly ``after`` clock units from now."""
        self._caller.call_later(func, after)

    def _call(self, func, timestamp, next_time):
        """Run ``func`` if due and reschedule it on its interval grid."""
        intv = self._subscribers[func]

        if next_time == 0:
            # First dispatch since register(): only schedule the first real
            # call one interval from now. Handle this explicitly instead of
            # deriving call_count == 0 from float math, because for large
            # timestamps (t - (t + intv)) // intv can round to -2. That would
            # schedule the first call at ~timestamp, i.e. on the very next step.
            self._call_time[func] = timestamp + intv
            return

        # Whole intervals elapsed since the call became due, plus one.
        call_count = int((timestamp - next_time) // intv) + 1

        # Catch up by at most two invocations, however far behind we are.
        for _ in range(min_(call_count, 2)):
            try:
                t0 = time.time()
                func()
                block = time.time() - t0
                if block > 0.1:
                    logging.warn('timer block {}ms for {}'.format(block*1000, func))
            except Exception:
                # Keep the timer alive on callback errors, but let
                # KeyboardInterrupt / SystemExit propagate.
                catch_exception()

        # Stay aligned to the original grid instead of drifting to `timestamp`.
        self._call_time[func] = next_time + call_count*intv


class _Caller:
    def __init__(self, timer, precision=0.25):
        self.timer = timer
        self._now = None

        self._call_times = []
        self._call_funcs = []
        self._suspended = []  # [(func, after), ...]

        timer.register(self.on_timer, precision)

    def on_timer(self):
        now = self.timer.now
        self._now = now

        if self._suspended:
            self._handle_suspended()

        times = self._call_times
        funcs = self._call_funcs

        while times:
            if times[0] <= now:
                times.pop(0)
                func = funcs.pop(0)
                try:
                    func()
                except:
                    catch_exception()
            else:
                break

    def call_later(self, func, after):
        if self._now is None:
            self._suspended.append((func, after))

        else:
            call_time = self._now + after
            index = bisect.bisect(self._call_times, call_time)

            self._call_times.insert(index, call_time)
            self._call_funcs.insert(index, func)

    def _handle_suspended(self):
        for func, after in self._suspended:
            self.call_later(func, after)
        self._suspended.clear()


class ChannelSelecting:
    """Track receive lag per channel and periodically reconnect the slowest one.

    Channels are grouped by (event, exchange, symbol, frequency); each group
    gets a ``SingleChannelSelector``. Receive events update the selectors, and
    a daemon thread calls ``select()`` on every selector once every 60 seconds.
    """

    def __init__(self, markets):
        from quant.markets.markets import Channel
        self._Channel = Channel

        self.markets = markets
        self.info_engine = SelectInfoEngine()

        self._compare_count = -1  # replaced by start(compare_count)
        self._channel_selectors = {}  # {channel_key: SingleChannelSelector(), ...}
        self._this_recv: tuple = None  # (channel, perf_counter) of the latest channel receive

        # Background loop that asks every selector to pick a channel to reconnect.
        self._thread = Thread(target=self._run_select, daemon=True)

    def start(self, compare_count=1000):
        """Subscribe to the markets' receive events and start the selection thread.

        :param compare_count: samples per channel a selector needs before comparing.
        """
        self._compare_count = compare_count
        self.markets.info_engine.subscribe_channel_recv(self._on_channel_recv)
        self.markets.info_engine.subscribe_process_recv(self._on_process_recv)
        self._thread.start()

    def tabulate_lag(self):
        """Return a PrettyTable of mean lag per channel id: ms with sample count, or 'Null'."""
        from prettytable import PrettyTable

        rows = []
        # Snapshot both levels: selectors and their histories can grow while
        # receive callbacks run, and iterating a live dict view would then
        # raise RuntimeError.
        for name, selector in list(self._channel_selectors.items()):
            for _id, dq in list(selector._channel_lag_history.items()):
                samples = list(dq)
                if not samples:
                    mean = 'Null'
                else:
                    mean = sum(samples) / len(samples)
                    mean = '{:.2f}({})'.format(mean*1000, len(samples))
                rows.append((_id, name, mean))

        table = PrettyTable(['id', 'channel', 'lag'])
        table.add_rows(rows)
        return table

    def dict_server_time_lag_detail(self):
        """Return {'<channel key>(<channel id>)': mean (recv_time - server_time)} for all channels."""
        result = {}
        for selector in list(self._channel_selectors.values()):
            for id_, lag in selector.dict_server_to_local().items():
                ch = self._Channel.all_instance[id_]
                name = '{}({})'.format(self._channel_key(ch), id_)
                result[name] = lag
        return result

    def _run_select(self):
        """Thread body: every 60 s, let each selector reconnect its worst channel."""
        while True:
            time.sleep(60)
            # _on_process_recv adds selectors from another thread. Iterating
            # the live dict here could raise "dictionary changed size during
            # iteration", so iterate a snapshot instead.
            for selector in list(self._channel_selectors.values()):
                # Isolate failures so one broken group does not skip the rest.
                try:
                    selector.select()
                except Exception:
                    catch_exception()

    def _on_channel_recv(self, channel):
        """Remember which channel just received data and when (perf_counter)."""
        perf_counter = time.perf_counter()
        self._this_recv = (channel, perf_counter)

    def _on_process_recv(self, data_id, server_time, recv_time):
        """Attribute a processed data item to the last receiving channel and record its lag."""
        if self._this_recv is None:
            return

        channel, perf_counter = self._this_recv
        key = self._channel_key(channel)
        selectors = self._channel_selectors

        selector = selectors.get(key)
        if selector is None:
            selector = SingleChannelSelector(self.markets, self, self._compare_count)
            selectors[key] = selector

        # A falsy server_time is recorded as a delay of 0.
        server_to_local = recv_time - server_time if server_time else 0
        selector.record_lag(channel, data_id, perf_counter, server_to_local)

    def _channel_key(self, channel):
        """Group key: '<event>_<exchange>_<symbol>_<frequency>'."""
        return '{}_{}_{}_{}'.format(channel.event, channel.exchange, channel.symbol, channel.frequency)


class SingleChannelSelector:
    """Compare receive lag among channels carrying the same data stream.

    For each ``data_id``, the first channel to deliver it records lag 0, and
    every later delivery records its delay relative to that first arrival.
    Once every channel holds ``compare_count`` samples, ``select()`` reconnects
    the channel with the highest mean lag.
    """

    def __init__(self, markets, channel_selecting, compare_count):
        from quant.markets.markets import Channel
        self._Channel = Channel

        self.markets = markets
        self.channel_selecting = channel_selecting
        self.compare_count = compare_count

        self._channel_lag_history = {}  # {channel_id: deque(maxlen=compare_count) of lags}
        self._channel_server_to_local = {}  # {channel_id: deque(maxlen=compare_count) of server->local delays}
        # {data_id: perf_counter of first arrival}; presumably capped at 200 entries by LimitDict.
        self._first_recv_time = LimitDict(200)

    def record_lag(self, channel, data_id, perf_counter, server_to_local):
        """Record one delivery of ``data_id`` by ``channel`` at ``perf_counter``."""
        channel_id = channel.channel_id
        if channel_id not in self._channel_lag_history:
            window = self.compare_count
            self._channel_lag_history[channel_id] = deque(maxlen=window)
            self._channel_server_to_local[channel_id] = deque(maxlen=window)

        first_seen = self._first_recv_time
        if data_id in first_seen:
            lag = perf_counter - first_seen[data_id]
        else:
            # This channel is the first to deliver this item.
            first_seen[data_id] = perf_counter
            lag = 0

        self._channel_lag_history[channel_id].append(lag)
        self._channel_server_to_local[channel_id].append(server_to_local)

    def select(self):
        """Reconnect the slowest channel once every channel has a full sample window.

        Holds ``markets.lock`` throughout. Does nothing while any window is
        short or when fewer than two channels are tracked.
        """
        needed = self.compare_count
        history = self._channel_lag_history

        with self.markets.lock:
            if any(len(samples) < needed for samples in history.values()):
                return

            means = [(cid, sum(samples) / needed) for cid, samples in history.items()]
            if len(means) < 2:
                return  # nothing to compare against

            # Stable sort by mean lag; the last entry is the slowest channel.
            means.sort(key=self._sort_key)
            worst_id, worst_lag = means[-1]

            channel = self._Channel.all_instance[worst_id]
            if channel.is_open():
                channel.reconnect()
            self.channel_selecting.info_engine.put_reconnect(channel, worst_lag)

            # Restart the worst channel's window so it is judged on fresh samples.
            history[worst_id].clear()

    def dict_server_to_local(self):
        """Return {channel_id: mean server-to-local delay}; the divisor is guarded by max_(len, 1)."""
        averages = {}
        for channel_id, samples in self._channel_server_to_local.items():
            averages[channel_id] = sum(samples) / max_(len(samples), 1)
        return averages

    def _sort_key(self, x):
        """Sort (channel_id, mean_lag) pairs by mean lag."""
        _, mean_lag = x
        return mean_lag


class SelectInfoEngine(EventEngine):
    """Event bus announcing reconnects; ``'Reconnect'`` carries (channel, lag)."""

    _TOPIC_RECONNECT = 'Reconnect'

    def subscribe_reconnect(self, handler):
        """Register ``handler`` for the ``'Reconnect'`` topic."""
        self.subscribe(self._TOPIC_RECONNECT, handler)

    def put_reconnect(self, ch, lag):
        """Publish that channel ``ch`` was selected for reconnect with mean ``lag``."""
        self.put(self._TOPIC_RECONNECT, ch, lag)


def pick_handler(market, event, exchange, symbol, frequency):
    """Return the handler cached in ``market.handlers`` for this key, creating it if absent.

    On a cache miss the class comes from ``get_factory(exchange).handler_cls``.
    It is instantiated as ``cls(event, exchange, symbol, frequency, market)``
    and stored back into ``market.handlers`` before being returned.
    """
    try:
        return market.handlers.get(event, exchange, symbol, frequency)
    except KeyError:
        pass

    # Imported only on a cache miss (presumably to avoid an import cycle).
    from quant.exchanges import get_factory
    handler_cls = get_factory(exchange).handler_cls
    handler = handler_cls(event, exchange, symbol, frequency, market)
    market.handlers.set(event, exchange, symbol, frequency, handler)
    return handler


def set_value_with_book(markets):
    """Re-estimate base-asset values from order-book mid prices.

    ``markets.order_books.items()`` must yield (exchange, symbol, book)
    triples, with ``symbol`` shaped like ``'<asset>/<money>'`` and optionally
    followed by ``'.<suffix>'``. Books that are None, or whose
    ``get_middle()`` is None, are skipped. Each remaining book contributes the
    estimate ``mid * get_asset_value(money)``. Each asset is then set to the
    mean of its estimates. All estimates use the values from before this call.
    """
    estimates_by_asset = {}
    for _exchange, symbol, book in markets.order_books.items():
        if book is None:
            continue
        middle = book.get_middle()
        if middle is None:
            continue

        pair = symbol.split('.')[0]
        asset, money = pair.split('/')
        estimate = middle * functions.get_asset_value(money)
        estimates_by_asset.setdefault(asset, []).append(estimate)

    for asset, estimates in estimates_by_asset.items():
        functions.set_asset_value(asset, sum(estimates) / len(estimates))


_timer_precision = 0.01


if __name__ == '__main__':
    # Manual smoke test for set_value_with_book and Timer.call_later.
    from utils import build_test_book
    from quant.utils import set_test_mode
    functions.set_quick_mode()
    set_test_mode()

    m = Mapping3()

    def try_set_value_with_book():
        """Value btc/eth from three test books and print the results (not invoked below)."""
        class Mar:
            order_books = m

        for pair, price in (('btc/usdt', 100), ('eth/btc', 0.11), ('eth/usdt', 11)):
            m.set('e1', pair, build_test_book(price))

        # Several passes; presumably so cross rates (eth via btc) settle.
        for _ in range(4):
            set_value_with_book(Mar)
        print(functions.get_asset_value('btc'))
        print(functions.get_asset_value('eth'))

    tm = Timer()

    def fn():
        print('call at', tm.now)

    begin = 10100101
    # Tick a few times first, so call_later below is scheduled immediately
    # instead of being parked until the internal caller sees its first tick.
    for _ in range(4):
        tm.step(begin)
        begin += 1

    print(begin)

    for delay in (4, 3, 2):
        tm.call_later(fn, delay)

    tm.step(begin)
    for _ in range(3):
        begin += 1
        tm.step(begin)








